import data_generation
import create_networks
import smooth_dp_utils
import itertools
import torch
import numpy as np
import os
import pickle


def check_or_create_folder(folder_name):
    """Ensure ``folder_name`` exists, creating intermediate directories as needed.

    Returns:
        A human-readable status message saying whether the folder was created
        or was already present.
    """
    if os.path.exists(folder_name):
        return f"Folder '{folder_name}' already exists."
    os.makedirs(folder_name)
    return f"Folder '{folder_name}' created."


def create_graph(M, sparsity, seed_n):
    """Build a binary adjacency matrix, its prior cost matrix and its edge list.

    Args:
        M: forwarded to ``create_networks.get_bin_M`` (presumably the node count).
        sparsity: forwarded to ``create_networks.get_bin_M``.
        seed_n: accepted for interface compatibility; not used in this function.

    Returns:
        (bin_M, prior_M, E, M_indices):
          bin_M      tensor returned by ``get_bin_M``;
          prior_M    float32 tensor from ``get_adj_cost_M(bin_m=bin_M)``;
          E          number of edges, i.e. ``bin_M.sum()`` as an int;
          M_indices  (E, 2) long tensor of (row, col) positions where bin_M == 1,
                     in row-major order.
    """
    bin_M = torch.tensor(create_networks.get_bin_M(M=M, sparsity=sparsity))
    prior_M = torch.tensor(create_networks.get_adj_cost_M(bin_m = bin_M), dtype=torch.float32)

    E = int(bin_M.sum().detach().numpy())

    adjacency = bin_M.detach().numpy()
    edge_rows, edge_cols = np.where(adjacency == 1)
    edge_array = np.zeros((E, 2))
    # Filling a pre-sized (E, 2) buffer fails loudly if the count of ==1
    # entries does not match the matrix sum.
    edge_array[:, 0] = edge_rows
    edge_array[:, 1] = edge_cols
    M_indices = torch.tensor(edge_array, dtype=torch.long)

    return bin_M, prior_M, E, M_indices
    
    
def generate_synthetic_data(N_train, N_val, N_test, noise_data, E, M_indices, prior_M, seed_n):
    """Generate train/validation/test splits of features and edge-cost deltas.

    Each split is drawn with ``data_generation.gen_data`` using seeds ``seed_n``,
    ``seed_n + 100`` and ``seed_n + 200`` respectively. The outputs are converted
    to float32 tensors, and each split's deltas are turned into full cost
    matrices with ``costs_to_matrix(prior_M, M_indices, ., pred=False)``.

    Returns:
        X, X_val, X_test, dY, dY_val, dY_test, M_Y, M_Y_val, M_Y_test
    """
    split_specs = ((N_train, seed_n), (N_val, seed_n + 100), (N_test, seed_n + 200))
    raw_splits = [
        data_generation.gen_data(n_rows, E, nl=noise_data, seed_number=seed, samples_dist=1)
        for n_rows, seed in split_specs
    ]

    features = [torch.tensor(x_part, dtype=torch.float32) for x_part, _, _ in raw_splits]
    deltas = [torch.tensor(dy_part, dtype=torch.float32) for _, dy_part, _ in raw_splits]
    cost_mats = [costs_to_matrix(prior_M, M_indices, dy_part, pred=False) for dy_part in deltas]

    return (*features, *deltas, *cost_mats)


def costs_to_matrix(prior, M_indices, dY, pred=True):
    """Overlay per-edge cost deltas on a prior cost matrix, once per batch row.

    Args:
        prior: 2-D tensor (V x V) of prior costs.
        M_indices: (E, 2) integer edge list (tensor, ndarray or nested list);
            row n gives the (i, j) matrix position of edge n. Rows are assumed
            unique (as produced by ``np.where`` in ``create_graph``); with
            duplicate rows, which write wins is unspecified.
        dY: (N, >=E) tensor; column n is added to ``prior[i, j]`` for edge n.
        pred: unused; kept so existing callers passing it keep working.

    Returns:
        (N, V, V) tensor with ``prior`` copied for every batch row and each
        listed edge set to ``prior[i, j] + dY[:, n]``, cast to ``prior``'s dtype.
        Every entry is then clamped below at 0.001, including non-edges.
    """
    N = dY.shape[0]  # batch size
    Mat = prior.unsqueeze(0).expand((N, prior.shape[0], prior.shape[1])).clone()

    idx = torch.as_tensor(M_indices, device=prior.device).long()
    rows, cols = idx[:, 0], idx[:, 1]
    n_edges = rows.shape[0]

    # One vectorised write replaces the per-edge Python loop. The explicit cast
    # matters: advanced-index assignment requires matching dtypes, whereas the
    # old scalar-index assignment silently cast (e.g. float64 dY into float32).
    Mat[:, rows, cols] = (prior[rows, cols].unsqueeze(0) + dY[:, :n_edges]).to(Mat.dtype)
    return Mat.clamp(0.001, None)


def source_end_nodes_permutation(M, perc_end_nodes_seen):
    """Randomly split all node pairs (i, j) with i < j into "seen" and "unseen".

    A fraction ``perc_end_nodes_seen`` of the M*(M-1)/2 pairs (rounded down) is
    drawn without replacement with the global NumPy RNG; the rest are unseen.

    Returns:
        (seen_permutations, unseen_permutations) as lists of (i, j) tuples.
    """
    # combinations() yields exactly the i < j pairs, in lexicographic order.
    pairs = list(itertools.combinations(range(M), 2))
    n_pairs = len(pairs)

    n_seen = int(perc_end_nodes_seen * n_pairs)
    seen_idx = np.random.choice(n_pairs, n_seen, replace=False)
    unseen_idx = np.array(list(set(np.arange(0, n_pairs)) - set(seen_idx)))

    seen = [pairs[k] for k in seen_idx]
    unseen = [pairs[k] for k in unseen_idx]
    return seen, unseen


def gen_source_end_nodes_train(seen_permutations, N_train):
    """Draw ``N_train`` (source, end) pairs uniformly, with replacement.

    Pairs come from ``seen_permutations``. One ``np.random.choice`` call is made
    per sample, so results follow the global NumPy seed.

    Returns:
        int ndarray of shape (N_train, 2).
    """
    n_choices = len(seen_permutations)
    pairs = np.zeros((N_train, 2))
    for row in range(N_train):
        pairs[row, :] = seen_permutations[np.random.choice(n_choices)]
    return pairs.astype(int)


def gen_paths(end_to_end_nodes_train, N_train, M_Y, BBB=50):
    """Compute demonstration paths with ``smooth_dp_utils.batch_dijkstra`` in batches.

    Args:
        end_to_end_nodes_train: per-sample (source, end) pairs, indexable by slice.
        N_train: number of samples to process (rows 0..N_train-1).
        M_Y: per-sample cost matrices, indexable by slice.
        BBB: batch size passed to each ``batch_dijkstra`` call.

    Returns:
        Flat list concatenating, in order, the items ``batch_dijkstra`` returns
        for each batch. This is presumably one path per sample; TODO confirm.
    """
    paths_demonstration_train = []
    # Step over every batch start so a final partial batch (N_train % BBB
    # samples) is included. The previous range(N_train // BBB) dropped those
    # samples, and returned nothing at all when N_train < BBB.
    for start in range(0, N_train, BBB):
        stop = min(start + BBB, N_train)
        batch_paths = smooth_dp_utils.batch_dijkstra(
            M_Y[start:stop], end_to_end_nodes_train[start:stop])
        paths_demonstration_train.extend(batch_paths)
    return paths_demonstration_train


def process_paths(paths_demonstration_train, nodes_in_cluster_sorted, M_indices,
    seed_n, M, sparsity, noise_data, perc_end_nodes_seen, train=True):
    """Encode node-sequence paths as index sequences and edge indicators, with a disk cache.

    For each path:
      * ``np.searchsorted(nodes_in_cluster_sorted, path)`` maps the nodes to indices;
      * consecutive index pairs form the path's edge sequence;
      * a 0/1 vector over the rows of ``M_indices`` marks which edges were traversed.

    The result is pickled to ``./data_synthetic_gen/``. The file name is built
    from ``seed_n, M, sparsity, noise_data, perc_end_nodes_seen`` plus ``_val``
    when ``train`` is False. If that file already exists it is loaded and
    returned instead of recomputing.

    NOTE(review): the cache key does not include the paths or their count.
    Re-running with different paths but the same settings silently returns the
    stale cached result; delete the file to force a rebuild.

    Returns:
        dict with keys "node_idx_sequence_trips", "edges_seq_original",
        "edges_idx_on_original", "start_nodes_original" and
        "end_nodes_original". Each value is a list with one entry per path.
    """
    prefix_train = '' if train else '_val'
    file_data_process = f'./data_synthetic_gen/{seed_n}_{M}_{sparsity}_{noise_data}_{perc_end_nodes_seen}_{prefix_train}.pkl'

    if os.path.exists(file_data_process):
        # Local cache written by this function. pickle.load can execute
        # arbitrary code, so never point this path at untrusted files.
        with open(file_data_process, 'rb') as file:
            return pickle.load(file)

    # Edge list as hashable (i, j) tuples, built once rather than per path.
    edge_list = [tuple(edge) for edge in np.asarray(M_indices).tolist()]

    node_idx_sequence_trips = []
    edges_seq_original = []
    edges_idx_on_original = []
    start_nodes_original = []
    end_nodes_original = []

    for node_sequence_trip in paths_demonstration_train:
        node_idx_sequence_trip = np.searchsorted(nodes_in_cluster_sorted, node_sequence_trip)
        node_idx_sequence_trips.append(node_idx_sequence_trip)
        start_nodes_original.append(node_idx_sequence_trip[0])
        end_nodes_original.append(node_idx_sequence_trip[-1])

        edges_sequence_trip = np.column_stack([node_idx_sequence_trip[:-1], node_idx_sequence_trip[1:]])
        edges_seq_original.append(edges_sequence_trip)

        # Set lookup: O(E + L) per path instead of comparing every edge of
        # M_indices against every step of the path.
        traversed = {tuple(step) for step in edges_sequence_trip.tolist()}
        edges_idx_on_original.append(np.array([1 if edge in traversed else 0 for edge in edge_list]))

    data = {
        "node_idx_sequence_trips": node_idx_sequence_trips,
        "edges_seq_original": edges_seq_original,
        "edges_idx_on_original": edges_idx_on_original,
        "start_nodes_original": start_nodes_original,
        "end_nodes_original": end_nodes_original
    }

    _ = check_or_create_folder('data_synthetic_gen')

    with open(file_data_process, 'wb') as file:
        pickle.dump(data, file)
    print("Data saved to", file_data_process)

    return data
    

def combined_distance(sample, data):
    """Mean absolute difference between ``sample`` and each row of ``data``.

    Only the first three columns are compared. Returns a 1-D tensor with one
    distance per row of ``data``.
    """
    gaps = [(data[:, col] - sample[col]).abs() for col in range(3)]
    return (gaps[0] + gaps[1] + gaps[2]) / 3

def find_k_similar_indices(data, k):
    """Pick a random row of ``data`` and return it with its ``k`` nearest rows.

    The anchor is drawn with ``torch.randint``. Distances come from
    ``combined_distance``, and the anchor excludes itself by having its distance
    set to +inf.

    Returns:
        1-D long tensor of length k + 1: the anchor index first, then the
        indices of the k closest rows in ascending distance order.
    """
    anchor = torch.randint(0, len(data), (1,))
    dist = combined_distance(data[anchor.item()], data)
    dist[anchor] = float('inf')
    _, nearest = dist.topk(k, largest=False)
    return torch.cat((anchor, nearest))

def generate_n_combinations(data, k, n):
    """Stack ``n`` independent draws of ``find_k_similar_indices(data, k)``.

    Returns:
        Long tensor of shape (n, k + 1).
    """
    rows = []
    for _ in range(n):
        rows.append(find_k_similar_indices(data, k))
    return torch.stack(rows)


    
def edges_on_pred(best_pred_path, M_indices):
    """Encode each predicted path as a 0/1 indicator over the edge list.

    Args:
        best_pred_path: sequence of paths; each path is a sequence of node ids.
        M_indices: (E, 2) torch tensor; row e is the directed edge (i, j).

    Returns:
        float ndarray of shape (len(best_pred_path), E). Entry [p, e] is 1.0 when
        edge e joins two consecutive nodes of path p, in that direction, and
        0.0 otherwise.
    """
    # Convert the edge list once; the old code re-ran .detach().numpy() for
    # every path. .cpu() also lets CUDA tensors through.
    edge_list = [tuple(edge) for edge in M_indices.detach().cpu().numpy().tolist()]

    indicators = np.zeros((len(best_pred_path), M_indices.shape[0]))
    for row, path in enumerate(best_pred_path):
        steps = np.column_stack([path[:-1], path[1:]])
        traversed = {tuple(step) for step in steps.tolist()}
        indicators[row] = [1 if edge in traversed else 0 for edge in edge_list]
    return indicators


def _batched_dijkstra(adj, end_nodes, n_items, bs):
    """Run smooth_dp_utils.batch_dijkstra on rows 0..n_items-1 in chunks of bs, flattening the results."""
    paths = []
    # Stepping by bs includes a final partial chunk, so exactly n_items rows are covered.
    for start in tqdm(range(0, n_items, bs)):
        stop = min(start + bs, n_items)
        paths.extend(smooth_dp_utils.batch_dijkstra(adj[start:stop], end_nodes[start:stop]))
    return paths


def eval_dijkstra(N_eval, n_paths):
    """Score Dijkstra paths on predicted costs against the demonstration edges.

    The model is evaluated once, without gradients. Two kinds of cost matrices
    are built from its outputs:

    * ``adj_main``: the mean prediction;
    * ``adjs_probs``: ``n_paths`` noisy samples ``mean + sigma * N(0, 1)``,
      clamped below at 0.001.

    Each is solved on the first ``N_eval`` samples and compared with
    ``edges_idx_on_original`` through ``compute_metrics_percentage``.

    NOTE(review): this function reads names that are not defined anywhere in
    this file's visible code: model, X, prior_M, M_indices, dY_to_M_eval,
    end_to_end_nodes_original, edges_idx_on_original,
    compute_metrics_percentage and tqdm. They must exist as globals before
    calling.

    Returns:
        (met.mean(0), best_last): the per-metric mean over samples for the mean
        prediction (6 values, matching the (N_eval, 6) buffer), and the
        sample-mean of the best-over-``n_paths`` value of the last metric.
    """
    bs = 20
    with torch.no_grad():
        dY_, dsigmaY_ = model(X)
        sigmaY_ = dsigmaY_ + 0.2*prior_M[M_indices[:,0], M_indices[:,1]]

        adj_main = dY_to_M_eval(dY_[:N_eval, :], N_eval, prior_M)
        M_Y_pred_sigma_ = dY_to_M_eval(sigmaY_[:N_eval, :], N_eval, 0.*prior_M)

        # n_paths noisy copies of the mean; sigma broadcasts over the leading dim.
        M_Y_pred_probs = adj_main.unsqueeze(0).repeat(n_paths, 1, 1, 1)
        adjs_probs = (M_Y_pred_probs + M_Y_pred_sigma_*torch.randn_like(M_Y_pred_probs)).clamp(0.001)

        # Previously N_eval // bs batches were run, so when N_eval % bs != 0 the
        # metric loop below indexed past the end of edges_on_pred_np (IndexError).
        best_pred_path = _batched_dijkstra(adj_main, end_to_end_nodes_original, N_eval, bs)
        edges_on_pred_np = edges_on_pred(best_pred_path, M_indices)
        met = np.zeros((N_eval, 6))
        for i in range(N_eval):
            met[i] = compute_metrics_percentage(edges_idx_on_original[i], edges_on_pred_np[i])

        met_probs = np.zeros((n_paths, N_eval, 6))
        for p in range(n_paths):
            best_pred_probs_path = _batched_dijkstra(adjs_probs[p], end_to_end_nodes_original, N_eval, bs)
            edges_on_pred_probs = edges_on_pred(best_pred_probs_path, M_indices)
            for i in range(N_eval):
                met_probs[p, i] = compute_metrics_percentage(edges_idx_on_original[i], edges_on_pred_probs[i])

    return met.mean(0), met_probs[:, :, -1].max(0).mean()


from collections import Counter
from collections import deque

def get_nodes_and_freqs(node_idx_sequence_trips):
    """Count how often each node appears across all trips.

    Args:
        node_idx_sequence_trips: iterable of node sequences.

    Returns:
        (elements, frequencies) as ndarrays. Nodes appear in first-occurrence
        order, with frequencies aligned to them. Both arrays are empty when the
        trips contain no nodes; previously this case crashed with a
        tuple-unpacking ValueError.
    """
    node_count_freq = Counter(itertools.chain.from_iterable(node_idx_sequence_trips))
    if not node_count_freq:
        return np.array([], dtype=int), np.array([], dtype=int)
    elements, frequencies = zip(*node_count_freq.items())
    return np.array(elements), np.array(frequencies)

def find_close_nodes(edge_tensor, start_node, num_nodes):
    """Collect up to ``num_nodes`` nodes reachable from ``start_node`` by breadth-first search.

    Edges are treated as undirected: each row (u, v) of ``edge_tensor`` links u
    and v both ways. Each row must support ``.tolist()``. Neighbours are
    expanded in the order edges appear in ``edge_tensor``. A start node with no
    edges yields just ``[start_node]`` when ``num_nodes >= 1``.

    Returns:
        list of visited node ids, in set iteration order.
    """
    adjacency = {}
    for row in edge_tensor:
        u, v = row.tolist()
        adjacency.setdefault(u, []).append(v)
        adjacency.setdefault(v, []).append(u)

    seen = set()
    frontier = deque([start_node])
    while frontier and len(seen) < num_nodes:
        node = frontier.popleft()
        if node in seen:
            continue
        seen.add(node)
        for nxt in adjacency.get(node, ()):
            if nxt not in seen:
                frontier.append(nxt)

    return list(seen)


from multiprocessing import Pool

def process_chunk(chunk_data, nodes_selected_set, map_nodes_selected):
    """Restrict a chunk of trips to the selected nodes and remap their ids.

    Args:
        chunk_data: (trips, offset); ``offset`` is the index of ``trips[0]``
            in the full trip list.
        nodes_selected_set: nodes to keep.
        map_nodes_selected: old node id -> new node id, for kept nodes.

    Returns:
        (kept_trips, kept_indices): remapped trips that still have at least two
        nodes, and their global indices (``offset + position``).
    """
    trips, offset = chunk_data
    kept_trips, kept_indices = [], []
    for position, trip in enumerate(trips):
        remapped = [map_nodes_selected[node] for node in trip if node in nodes_selected_set]
        if len(remapped) < 2:
            continue
        kept_trips.append(remapped)
        kept_indices.append(offset + position)
    return kept_trips, kept_indices

def selected_trips_and_idx(node_idx_sequence_trips, M_indices, elements, frequencies, Vs, V):
    """Sample ``Vs`` nodes and restrict every trip to them.

    Node sampling has two parts:
      * up to ``Vs // 2`` nodes come from ``find_close_nodes``, run on a
        randomly shuffled copy of ``M_indices`` and started at a random node
        of ``elements``;
      * the rest are drawn without replacement from the remaining
        ``elements``, with probability proportional to ``frequencies``.

    Trips are then filtered and remapped in parallel by ``process_chunk``.

    Args:
        node_idx_sequence_trips: list of node-id sequences, one per trip.
        M_indices: (E, 2) torch tensor of edges.
        elements, frequencies: node ids and their counts, as ndarrays (the
            shapes ``get_nodes_and_freqs`` returns).
        Vs: number of nodes to select.
        V: total node count; ids in range(V) that are not selected are
            returned as excluded.

    Returns:
        (selected_indexes, selected_trips, nodes_selected, nodes_excluded):
        indices of trips keeping >= 2 selected nodes, those trips remapped to
        0..Vs-1, and sorted tensors of the selected and excluded node ids.
        All four are None when not enough nodes could be selected.
    """
    shuffled_indices = torch.randperm(M_indices.size(0))
    M_indices_shuf = M_indices[shuffled_indices]
    nodes_group = find_close_nodes(M_indices_shuf, np.random.choice(elements, 1).item(), Vs//2)

    elements_rest = list(set(elements) - set(nodes_group))
    # Position of each remaining node in ``elements``. One dict replaces an
    # np.where scan per node, and dtype=int keeps an empty result usable as an index.
    position_of = {el: i for i, el in enumerate(elements)}
    args_rest = np.array([position_of[el] for el in elements_rest], dtype=int)

    n_needed = Vs - len(nodes_group)
    if n_needed > len(args_rest):
        # Sampling without replacement would raise ValueError inside
        # np.random.choice; report it through the usual sentinel instead.
        print('Did not select enough nodes in this batch. Might lead to bad results!')
        return None, None, None, None

    nodes_selected = np.random.choice(
        elements[args_rest], size=n_needed,
        p=np.array(frequencies[args_rest])/sum(frequencies[args_rest]), replace=False)

    nodes_selected = np.array(list(set(list(nodes_group)).union(nodes_selected)))

    if len(nodes_selected)<Vs:
        print('Did not select enough nodes in this batch. Might lead to bad results!')
        return None, None, None, None

    nodes_selected = torch.tensor(nodes_selected)

    nodes_excluded = torch.tensor(np.array(list(set(np.arange(0, V)) - set(nodes_selected.numpy()))))

    nodes_selected, _ = nodes_selected.sort()
    nodes_excluded, _ = nodes_excluded.sort()

    # Old node id -> position in the sorted selection (0..Vs-1).
    map_nodes_selected = dict(zip(nodes_selected.numpy(), np.arange(0, Vs)))

    nodes_selected_set = set(nodes_selected.numpy())

    # Split the trips into ~16 contiguous chunks. Each chunk carries its start
    # offset so that process_chunk can report global trip indices.
    chunk_size = (len(node_idx_sequence_trips) // 16) + 1
    chunks = [(node_idx_sequence_trips[i:i + chunk_size], i) for i in range(0, len(node_idx_sequence_trips), chunk_size)]

    with Pool() as pool:
        results = pool.starmap(process_chunk, [(chunk, nodes_selected_set, map_nodes_selected) for chunk in chunks])

    # starmap preserves chunk order, so concatenation keeps the original trip order.
    selected_trips = []
    selected_indexes = []
    for new_list, indices in results:
        selected_trips.extend(new_list)
        selected_indexes.extend(indices)

    return selected_indexes, selected_trips, nodes_selected, nodes_excluded



def select_Ms_from_selected_idx_and_trips(M_Y_pred, M_sigmaY, Vs, M_indices, nodes_excluded, nodes_selected, beta, dev):
    """Eliminate excluded nodes from batched cost matrices, then crop to the selected nodes.

    Each node in ``nodes_excluded`` is removed in turn with
    ``smooth_dp_utils.remove_node_and_adjust_vectorized``. Before each removal,
    the current edges are recomputed as positions where any batch element has
    a cost below 2000. NOTE(review): 2000 appears to act as a "no edge"
    sentinel; confirm.

    Args:
        M_Y_pred: (B, V, V) tensor of predicted costs.
        M_sigmaY: (B, V, V) tensor. It is cropped the same way but not passed
            through node elimination.
        Vs: number of selected nodes; must equal ``len(nodes_selected)``.
        M_indices, dev: unused; kept so existing call sites keep working.
        nodes_excluded: iterable of node ids to eliminate.
        nodes_selected: 1-D long tensor of node ids to keep, in output order.
        beta: forwarded to ``remove_node_and_adjust_vectorized``.

    Returns:
        (M_Y_pred_selected, M_sigmaY_selected, M_indices_selected_mapped):
        the two (B, Vs, Vs) cropped tensors, plus the (K, 2) positions in the
        cropped index space where some batch element of M_Y_pred_selected is
        below 2000.
    """
    M_Y_pred_new = M_Y_pred.clone()

    for n in nodes_excluded:
        edges_now = torch.argwhere((M_Y_pred_new<2000).sum(0)>0)
        mask_ind = (edges_now == n)

        # Nodes with an edge into n (column 1 == n) and out of n (column 0 == n).
        into_n = edges_now[mask_ind[:,1]][:,0]
        out_of_n = edges_now[mask_ind[:,0]][:,1]

        # Every (predecessor, successor) pair that may need a bypass edge.
        M_indices_to_check = torch.cartesian_prod(into_n, out_of_n)

        M_Y_pred_new = smooth_dp_utils.remove_node_and_adjust_vectorized(
            M_Y_pred_new, n, M_indices_to_check, beta)

    # Row-major (selected x selected) pairs. The former 3-way cartesian product
    # (Vs^3 rows) and the isin() edge filter were never used and are dropped.
    pairs = torch.cartesian_prod(nodes_selected, nodes_selected)

    # Advanced indexing already returns fresh tensors, so no extra clone is needed.
    M_Y_pred_selected = M_Y_pred_new[:, pairs[:,0], pairs[:,1]].reshape(M_Y_pred.shape[0], Vs, Vs)
    M_sigmaY_selected = M_sigmaY[:, pairs[:,0], pairs[:,1]].reshape(M_sigmaY.shape[0], Vs, Vs)

    M_indices_selected_mapped = torch.argwhere((M_Y_pred_selected<2000).sum(0)>0)

    return M_Y_pred_selected, M_sigmaY_selected, M_indices_selected_mapped




class CustomQueue:
    """List-backed node sequence refined by recursive random midpoint insertion.

    ``insert_queue(a, b)`` draws a node ``hi`` from ``probs_sample[a, b]``.
    Unless ``hi == a``, it inserts ``hi`` between the adjacent entries ``a``
    and ``b`` and then refines both new segments. If ``hi`` was already in the
    queue, the loop it closes is cut out and refinement of that segment stops.
    """

    def __init__(self, V, probs_sample):
        self.queue = []  # current sequence of node ids
        self.V = V  # insert_queue samples node ids from range(V)
        self.probs = probs_sample  # probs[a, b] is the `p` vector used for segment (a, b)

    def insert(self, index, value):
        """Insert ``value`` at ``index`` (0..len inclusive); print a message if out of range."""
        if 0 <= index <= len(self.queue):
            self.queue.insert(index, value)
        else:
            print("Index out of range")

    def remove_duplicates(self, value):
        """Cut the loop closed by a repeated ``value``.

        If ``value`` occurs more than once, delete everything from its first
        occurrence up to, but not including, its last occurrence, leaving a
        single copy.

        Returns:
            False if anything was deleted, True otherwise.
        """
        if self.queue.count(value) > 1:
            first_index = self.queue.index(value)
            last_index = len(self.queue) - 1 - self.queue[::-1].index(value)
            del self.queue[first_index:last_index]
            return False
        return True

    def insert_between_values(self, val1, val2, value):
        """Insert ``value`` between adjacent entries ``val1`` and ``val2``, then cut any loop.

        Uses the first occurrence of each value.

        Returns:
            True if ``value`` was inserted without duplicating an existing
            entry. False if the values were missing or not adjacent (nothing
            inserted), or if the insertion closed a loop that was then
            removed. Earlier versions returned None.
        """
        try:
            index1 = self.queue.index(val1)
            index2 = self.queue.index(val2)
        except ValueError:
            print("One or both values not found in the queue")
            return False
        if index2 - index1 == 1:
            self.insert(index2, value)
            return self.remove_duplicates(value)
        print("Values are not adjacent")
        return False

    def remove(self):
        """Pop and return the first entry; print a message and return None if empty."""
        if self.queue:
            return self.queue.pop(0)
        else:
            print("Queue is empty")

    def display(self):
        """Print the current queue."""
        print(self.queue)

    def insert_queue(self, a, b):
        """Recursively refine the segment between adjacent entries ``a`` and ``b``.

        Draws ``hi`` from ``probs[a, b]`` over range(V). Drawing ``a`` ends the
        segment. Otherwise ``hi`` is inserted and the halves (a, hi) and
        (hi, b) are refined, but only if the insertion succeeded without
        closing a loop.

        Previously the loop-cut result was discarded and remove_duplicates was
        called a second time, which always returned True. The recursion then
        continued on segments whose endpoints had been deleted or never
        inserted, causing spurious prints and possibly a RecursionError.
        The method is recursive, so very deep refinements can still reach
        Python's recursion limit.
        """
        hi = np.random.choice(np.arange(0, self.V), 1, p=self.probs[a,b]).item()
        if hi != a and self.insert_between_values(a, b, hi):
            self.insert_queue(a, hi)
            self.insert_queue(hi, b)